{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using CuTe Layout Algebra With Python DSL\n",
    "\n",
    "Referencing the [01_layout.md](https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/01_layout.md) and [02_layout_algebra.md](https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/02_layout_algebra.md) documentation from CuTe C++, we summarize:\n",
    "\n",
    "In CuTe, a `Layout`:\n",
    "- is defined by a pair of `Shape` and `Stride`,\n",
    "- maps coordinates space(s) to an index space,\n",
    "- supports both static (compile-time) and dynamic (runtime) values.\n",
    "\n",
    "CuTe also provides a powerful set of operations—the *Layout Algebra*—for combining and manipulating layouts, including:\n",
    "- Layout composition: Functional composition of layouts,\n",
    "- Layout \"divide\": Splitting a layout into two component layouts,\n",
    "- Layout \"product\": Reproducing a layout according to another layout.\n",
    "\n",
    "In this notebook, we will demonstrate:\n",
    "1. How to use CuTe’s key layout algebra operations with the Python DSL.\n",
    "2. How static and dynamic layouts behave when printed or manipulated within the Python DSL.\n",
    "\n",
    "We use examples from [02_layout_algebra.md](https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/02_layout_algebra.md) which we recommend to the reader for additional details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cutlass\n",
    "import cutlass.cute as cute"
   ]
  },
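  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea of a layout as a function from coordinates to indices concrete, the next cell gives a minimal pure-Python sketch that is independent of the CuTe DSL. It assumes CuTe's convention that the leftmost mode varies fastest when a 1-D coordinate is unpacked. The helper names `size` and `layout_index` are hypothetical and exist only for this illustration; they are not part of `cutlass.cute`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative model only (hypothetical helpers, not the cutlass.cute API).\n",
    "# Shapes and strides are ints or equally nested tuples of ints.\n",
    "def size(shape):\n",
    "    # Number of coordinates in a (possibly nested) shape.\n",
    "    if isinstance(shape, tuple):\n",
    "        n = 1\n",
    "        for s in shape:\n",
    "            n *= size(s)\n",
    "        return n\n",
    "    return shape\n",
    "\n",
    "def layout_index(i, shape, stride):\n",
    "    # Evaluate the layout shape:stride at 1-D coordinate i (leftmost mode fastest).\n",
    "    if not isinstance(shape, tuple):\n",
    "        return i * stride\n",
    "    total = 0\n",
    "    for s, d in zip(shape, stride):\n",
    "        total += layout_index(i % size(s), s, d)\n",
    "        i //= size(s)\n",
    "    return total\n",
    "\n",
    "# The layout (4,3):(3,1) maps coordinate (i, j) to 3*i + j, with i varying fastest.\n",
    "b_values = [layout_index(i, (4, 3), (3, 1)) for i in range(12)]\n",
    "print(b_values)"
   ]
  },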
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Layout Algebra Operations\n",
    "\n",
    "These operations form the foundation of CuTe's layout manipulation capabilities, enabling:\n",
    "- Efficient data tiling and partitioning,\n",
    "- Separation of thread and data layouts with a canonical type to represent both,\n",
    "- Native description and manipulation of hierarchical tensors of threads and data crucial for tensor core programs,\n",
    "- Mixed static/dynamic layout transformations,\n",
    "- Seamless integration of layout algebra with tensor operations,\n",
    "- Expression of complex MMA and copies as canonical loops."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Coalesce\n",
    "\n",
    "The `coalesce` operation simplifies a layout by flattening and combining modes when possible, without changing its size or behavior as a function on the integers.\n",
    "\n",
    "It ensures the post-conditions:\n",
    "- Preserve size: cute.size(layout) == cute.size(result),\n",
    "- Flattened: depth(result) <= 1,\n",
    "- Preserve functional: For all i, 0 <= i < cute.size(layout), layout(i) == result(i).\n",
    "\n",
    "#### Examples\n",
    "\n",
    "- Basic Coalesce Example :"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Original: (2,(1,6)):(1,(?,2))\n",
      ">>> Coalesced: 12:1\n",
      ">?? Original: (2,(1,6)):(1,(6,2))\n",
      ">?? Coalesced: 12:1\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def coalesce_example():\n",
    "    \"\"\"\n",
    "    Demonstrates coalesce operation flattening and combining modes\n",
    "    \"\"\"\n",
    "    layout = cute.make_layout((2, (1, 6)), stride=(1, (cutlass.Int32(6), 2))) # Dynamic stride\n",
    "    result = cute.coalesce(layout)\n",
    "\n",
    "    print(\">>> Original:\", layout)\n",
    "    cute.printf(\">?? Original: {}\", layout)\n",
    "    print(\">>> Coalesced:\", result)\n",
    "    cute.printf(\">?? Coalesced: {}\", result)\n",
    "\n",
    "coalesce_example()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Original: ((2,(3,4)),(3,2),1):((4,(8,24)),(2,6),12)\n",
      ">>> Coalesced: (24,6):(4,2)\n",
      ">>> Checking post-conditions:\n",
      ">>> 1. Checking size remains the same after the coalesce operation:\n",
      "Original size: 144, Coalesced size: 144\n",
      ">>> 2. Checking depth of coalesced layout <= 1:\n",
      "Depth of coalesced layout: 1\n",
      ">>> 3. Checking layout functionality remains the same after the coalesce operation:\n",
      "Index 0: original 0, coalesced 0\n",
      "Index 1: original 4, coalesced 4\n",
      "Index 2: original 8, coalesced 8\n",
      "Index 3: original 12, coalesced 12\n",
      "Index 4: original 16, coalesced 16\n",
      "Index 5: original 20, coalesced 20\n",
      "Index 6: original 24, coalesced 24\n",
      "Index 7: original 28, coalesced 28\n",
      "Index 8: original 32, coalesced 32\n",
      "Index 9: original 36, coalesced 36\n",
      "Index 10: original 40, coalesced 40\n",
      "Index 11: original 44, coalesced 44\n",
      "Index 12: original 48, coalesced 48\n",
      "Index 13: original 52, coalesced 52\n",
      "Index 14: original 56, coalesced 56\n",
      "Index 15: original 60, coalesced 60\n",
      "Index 16: original 64, coalesced 64\n",
      "Index 17: original 68, coalesced 68\n",
      "Index 18: original 72, coalesced 72\n",
      "Index 19: original 76, coalesced 76\n",
      "Index 20: original 80, coalesced 80\n",
      "Index 21: original 84, coalesced 84\n",
      "Index 22: original 88, coalesced 88\n",
      "Index 23: original 92, coalesced 92\n",
      "Index 24: original 2, coalesced 2\n",
      "Index 25: original 6, coalesced 6\n",
      "Index 26: original 10, coalesced 10\n",
      "Index 27: original 14, coalesced 14\n",
      "Index 28: original 18, coalesced 18\n",
      "Index 29: original 22, coalesced 22\n",
      "Index 30: original 26, coalesced 26\n",
      "Index 31: original 30, coalesced 30\n",
      "Index 32: original 34, coalesced 34\n",
      "Index 33: original 38, coalesced 38\n",
      "Index 34: original 42, coalesced 42\n",
      "Index 35: original 46, coalesced 46\n",
      "Index 36: original 50, coalesced 50\n",
      "Index 37: original 54, coalesced 54\n",
      "Index 38: original 58, coalesced 58\n",
      "Index 39: original 62, coalesced 62\n",
      "Index 40: original 66, coalesced 66\n",
      "Index 41: original 70, coalesced 70\n",
      "Index 42: original 74, coalesced 74\n",
      "Index 43: original 78, coalesced 78\n",
      "Index 44: original 82, coalesced 82\n",
      "Index 45: original 86, coalesced 86\n",
      "Index 46: original 90, coalesced 90\n",
      "Index 47: original 94, coalesced 94\n",
      "Index 48: original 4, coalesced 4\n",
      "Index 49: original 8, coalesced 8\n",
      "Index 50: original 12, coalesced 12\n",
      "Index 51: original 16, coalesced 16\n",
      "Index 52: original 20, coalesced 20\n",
      "Index 53: original 24, coalesced 24\n",
      "Index 54: original 28, coalesced 28\n",
      "Index 55: original 32, coalesced 32\n",
      "Index 56: original 36, coalesced 36\n",
      "Index 57: original 40, coalesced 40\n",
      "Index 58: original 44, coalesced 44\n",
      "Index 59: original 48, coalesced 48\n",
      "Index 60: original 52, coalesced 52\n",
      "Index 61: original 56, coalesced 56\n",
      "Index 62: original 60, coalesced 60\n",
      "Index 63: original 64, coalesced 64\n",
      "Index 64: original 68, coalesced 68\n",
      "Index 65: original 72, coalesced 72\n",
      "Index 66: original 76, coalesced 76\n",
      "Index 67: original 80, coalesced 80\n",
      "Index 68: original 84, coalesced 84\n",
      "Index 69: original 88, coalesced 88\n",
      "Index 70: original 92, coalesced 92\n",
      "Index 71: original 96, coalesced 96\n",
      "Index 72: original 6, coalesced 6\n",
      "Index 73: original 10, coalesced 10\n",
      "Index 74: original 14, coalesced 14\n",
      "Index 75: original 18, coalesced 18\n",
      "Index 76: original 22, coalesced 22\n",
      "Index 77: original 26, coalesced 26\n",
      "Index 78: original 30, coalesced 30\n",
      "Index 79: original 34, coalesced 34\n",
      "Index 80: original 38, coalesced 38\n",
      "Index 81: original 42, coalesced 42\n",
      "Index 82: original 46, coalesced 46\n",
      "Index 83: original 50, coalesced 50\n",
      "Index 84: original 54, coalesced 54\n",
      "Index 85: original 58, coalesced 58\n",
      "Index 86: original 62, coalesced 62\n",
      "Index 87: original 66, coalesced 66\n",
      "Index 88: original 70, coalesced 70\n",
      "Index 89: original 74, coalesced 74\n",
      "Index 90: original 78, coalesced 78\n",
      "Index 91: original 82, coalesced 82\n",
      "Index 92: original 86, coalesced 86\n",
      "Index 93: original 90, coalesced 90\n",
      "Index 94: original 94, coalesced 94\n",
      "Index 95: original 98, coalesced 98\n",
      "Index 96: original 8, coalesced 8\n",
      "Index 97: original 12, coalesced 12\n",
      "Index 98: original 16, coalesced 16\n",
      "Index 99: original 20, coalesced 20\n",
      "Index 100: original 24, coalesced 24\n",
      "Index 101: original 28, coalesced 28\n",
      "Index 102: original 32, coalesced 32\n",
      "Index 103: original 36, coalesced 36\n",
      "Index 104: original 40, coalesced 40\n",
      "Index 105: original 44, coalesced 44\n",
      "Index 106: original 48, coalesced 48\n",
      "Index 107: original 52, coalesced 52\n",
      "Index 108: original 56, coalesced 56\n",
      "Index 109: original 60, coalesced 60\n",
      "Index 110: original 64, coalesced 64\n",
      "Index 111: original 68, coalesced 68\n",
      "Index 112: original 72, coalesced 72\n",
      "Index 113: original 76, coalesced 76\n",
      "Index 114: original 80, coalesced 80\n",
      "Index 115: original 84, coalesced 84\n",
      "Index 116: original 88, coalesced 88\n",
      "Index 117: original 92, coalesced 92\n",
      "Index 118: original 96, coalesced 96\n",
      "Index 119: original 100, coalesced 100\n",
      "Index 120: original 10, coalesced 10\n",
      "Index 121: original 14, coalesced 14\n",
      "Index 122: original 18, coalesced 18\n",
      "Index 123: original 22, coalesced 22\n",
      "Index 124: original 26, coalesced 26\n",
      "Index 125: original 30, coalesced 30\n",
      "Index 126: original 34, coalesced 34\n",
      "Index 127: original 38, coalesced 38\n",
      "Index 128: original 42, coalesced 42\n",
      "Index 129: original 46, coalesced 46\n",
      "Index 130: original 50, coalesced 50\n",
      "Index 131: original 54, coalesced 54\n",
      "Index 132: original 58, coalesced 58\n",
      "Index 133: original 62, coalesced 62\n",
      "Index 134: original 66, coalesced 66\n",
      "Index 135: original 70, coalesced 70\n",
      "Index 136: original 74, coalesced 74\n",
      "Index 137: original 78, coalesced 78\n",
      "Index 138: original 82, coalesced 82\n",
      "Index 139: original 86, coalesced 86\n",
      "Index 140: original 90, coalesced 90\n",
      "Index 141: original 94, coalesced 94\n",
      "Index 142: original 98, coalesced 98\n",
      "Index 143: original 102, coalesced 102\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def coalesce_post_conditions():\n",
    "    \"\"\"\n",
    "    Demonstrates coalesce operation's 3 post-conditions:\n",
    "    1. size(@a result) == size(@a layout)\n",
    "    2. depth(@a result) <= 1\n",
    "    3. for all i, 0 <= i < size(@a layout), @a result(i) == @a layout(i)\n",
    "    \"\"\"\n",
    "    layout = cute.make_layout(\n",
    "        ((2, (3, 4)), (3, 2), 1),\n",
    "        stride=((4, (8, 24)), (2, 6), 12)\n",
    "    )\n",
    "    result = cute.coalesce(layout)\n",
    "\n",
    "    print(\">>> Original:\", layout)\n",
    "    print(\">>> Coalesced:\", result)\n",
    "\n",
    "    print(\">>> Checking post-conditions:\")\n",
    "    print(\">>> 1. Checking size remains the same after the coalesce operation:\")\n",
    "    original_size = cute.size(layout)\n",
    "    coalesced_size = cute.size(result)\n",
    "    print(f\"Original size: {original_size}, Coalesced size: {coalesced_size}\")\n",
    "    assert coalesced_size == original_size, \\\n",
    "            f\"Size mismatch: original {original_size}, coalesced {coalesced_size}\"\n",
    "    \n",
    "    print(\">>> 2. Checking depth of coalesced layout <= 1:\")\n",
    "    depth = cute.depth(result)\n",
    "    print(f\"Depth of coalesced layout: {depth}\")\n",
    "    assert depth <= 1, f\"Depth of coalesced layout should be <= 1, got {depth}\"\n",
    "\n",
    "    print(\">>> 3. Checking layout functionality remains the same after the coalesce operation:\")\n",
    "    for i in cutlass.range_constexpr(original_size):\n",
    "        original_value = layout(i)\n",
    "        coalesced_value = result(i)\n",
    "        print(f\"Index {i}: original {original_value}, coalesced {coalesced_value}\")\n",
    "        assert coalesced_value == original_value, \\\n",
    "            f\"Value mismatch at index {i}: original {original_value}, coalesced {coalesced_value}\"\n",
    "\n",
    "coalesce_post_conditions()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- By-mode Coalesce Example :"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Original:  (2,(1,6)):(1,(6,2))\n",
      ">>> Coalesced Result:  (2,6):(1,2)\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def bymode_coalesce_example():\n",
    "    \"\"\"\n",
    "    Demonstrates by-mode coalescing\n",
    "    \"\"\"\n",
    "    layout = cute.make_layout((2, (1, 6)), stride=(1, (6, 2)))\n",
    "\n",
    "    # Coalesce with mode-wise profile (1,1) = coalesce both modes\n",
    "    result = cute.coalesce(layout, target_profile=(1, 1))\n",
    "    \n",
    "    # Print results\n",
    "    print(\">>> Original: \", layout)\n",
    "    print(\">>> Coalesced Result: \", result)\n",
    "\n",
    "bymode_coalesce_example()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Composition\n",
    "\n",
    "`Composition` of Layout `A` with Layout `B` creates a new layout `R = A ◦ B` where:\n",
    "- The shape of `B` is compatible with the shape of `R` so that all coordinates of `B` can also be used as coordinates of `R`,\n",
    "- `R(c) = A(B(c))` for all coordinates `c` in `B`'s domain.\n",
    "\n",
    "Layout composition is very useful for reshaping and reordering layouts.\n",
    "\n",
    "#### Examples\n",
    "\n",
    "- Basic Composition Example :"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Layout A: (6,2):(?,2)\n",
      ">>> Layout B: (4,3):(3,1)\n",
      ">>> Composition R = A ◦ B: ((2,2),3):((?{div=3},2),?)\n",
      ">?? Layout A: (6,2):(8,2)\n",
      ">?? Layout B: (4,3):(3,1)\n",
      ">?? Composition R: ((2,2),3):((24,2),8)\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def composition_example():\n",
    "    \"\"\"\n",
    "    Demonstrates basic layout composition R = A ◦ B\n",
    "    \"\"\"\n",
    "    A = cute.make_layout((6, 2), stride=(cutlass.Int32(8), 2)) # Dynamic stride\n",
    "    B = cute.make_layout((4, 3), stride=(3, 1))\n",
    "    R = cute.composition(A, B)\n",
    "\n",
    "    # Print static and dynamic information\n",
    "    print(\">>> Layout A:\", A)\n",
    "    cute.printf(\">?? Layout A: {}\", A)\n",
    "    print(\">>> Layout B:\", B) \n",
    "    cute.printf(\">?? Layout B: {}\", B)\n",
    "    print(\">>> Composition R = A ◦ B:\", R)\n",
    "    cute.printf(\">?? Composition R: {}\", R)\n",
    "\n",
    "composition_example()"
   ]
  },
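  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The defining property `R(c) = A(B(c))` can be checked numerically. The next cell is a minimal pure-Python sketch that models each layout as a `(shape, stride)` pair and compares both sides for the runtime values printed above: `A = (6,2):(8,2)`, `B = (4,3):(3,1)` and `R = ((2,2),3):((24,2),8)`. The helpers `size` and `layout_index` are hypothetical names used only for this illustration (redefined here so the cell runs on its own); they are not part of `cutlass.cute`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative model only (hypothetical helpers, not the cutlass.cute API).\n",
    "def size(shape):\n",
    "    # Number of coordinates in a (possibly nested) shape.\n",
    "    if isinstance(shape, tuple):\n",
    "        n = 1\n",
    "        for s in shape:\n",
    "            n *= size(s)\n",
    "        return n\n",
    "    return shape\n",
    "\n",
    "def layout_index(i, shape, stride):\n",
    "    # Evaluate the layout shape:stride at 1-D coordinate i (leftmost mode fastest).\n",
    "    if not isinstance(shape, tuple):\n",
    "        return i * stride\n",
    "    total = 0\n",
    "    for s, d in zip(shape, stride):\n",
    "        total += layout_index(i % size(s), s, d)\n",
    "        i //= size(s)\n",
    "    return total\n",
    "\n",
    "# Layouts as (shape, stride) pairs, using the runtime values printed above.\n",
    "A = ((6, 2), (8, 2))\n",
    "B = ((4, 3), (3, 1))\n",
    "R = (((2, 2), 3), ((24, 2), 8))\n",
    "\n",
    "# Left side: apply B, then A. Right side: evaluate R directly.\n",
    "composed = [layout_index(layout_index(c, *B), *A) for c in range(size(B[0]))]\n",
    "direct = [layout_index(c, *R) for c in range(size(R[0]))]\n",
    "print(composed == direct)"
   ]
  },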
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Comparing Composition with static and dynamic layouts :\n",
    "\n",
    "In this case, the results may look different but are mathematically the same. The 1s in the shape don't affect the layout as a mathematical function on the integers. In the dynamic case, CuTe can not coalesce the dynamic size-1 modes to \"simplify\" the layout because it is not valid to do so for all possible dynamic values that parameter could realize at runtime."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Static composition:\n",
      ">>> A_static:  (10,2):(16,4)\n",
      ">>> B_static:  (5,4):(1,5)\n",
      ">>> R_static:  (5,(2,2)):(16,(80,4))\n",
      ">?? Dynamic composition:\n",
      ">?? A_dynamic: (10,2):(16,4)\n",
      ">?? B_dynamic: (5,4):(1,5)\n",
      ">?? R_dynamic: ((5,1),(2,2)):((16,4),(80,4))\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def composition_static_vs_dynamic_layout():\n",
    "    \"\"\"\n",
    "    Shows difference between static and dynamic composition results\n",
    "    \"\"\"\n",
    "    # Static version - using compile-time values\n",
    "    A_static = cute.make_layout(\n",
    "        (10, 2), \n",
    "        stride=(16, 4)\n",
    "    )\n",
    "    B_static = cute.make_layout(\n",
    "        (5, 4), \n",
    "        stride=(1, 5)\n",
    "    )\n",
    "    R_static = cute.composition(A_static, B_static)\n",
    "\n",
    "    # Static print shows compile-time info\n",
    "    print(\">>> Static composition:\")\n",
    "    print(\">>> A_static: \", A_static)\n",
    "    print(\">>> B_static: \", B_static)\n",
    "    print(\">>> R_static: \", R_static)\n",
    "\n",
    "    # Dynamic version - using runtime Int32 values\n",
    "    A_dynamic = cute.make_layout(\n",
    "        (cutlass.Int32(10), cutlass.Int32(2)),\n",
    "        stride=(cutlass.Int32(16), cutlass.Int32(4))\n",
    "    )\n",
    "    B_dynamic = cute.make_layout(\n",
    "        (cutlass.Int32(5), cutlass.Int32(4)),\n",
    "        stride=(cutlass.Int32(1), cutlass.Int32(5))\n",
    "    )\n",
    "    R_dynamic = cute.composition(A_dynamic, B_dynamic)\n",
    "    \n",
    "    # Dynamic printf shows runtime values\n",
    "    cute.printf(\">?? Dynamic composition:\")\n",
    "    cute.printf(\">?? A_dynamic: {}\", A_dynamic)\n",
    "    cute.printf(\">?? B_dynamic: {}\", B_dynamic)\n",
    "    cute.printf(\">?? R_dynamic: {}\", R_dynamic)\n",
    "\n",
    "composition_static_vs_dynamic_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "-  By-mode Composition Example :\n",
    "\n",
    "By-mode composition allows us to apply composition operations to individual modes of a layout. This is particularly useful when you want to manipulate specific modes layout independently (e.g. rows and columns).\n",
    "\n",
    "In the context of CuTe, by-mode composition is achieved by using a `Tiler`, which can be a layout or a tuple of layouts. The leaves of the `Tiler` tuple specify how the corresponding mode of the target layout should be composed, allowing for sublayouts to be treated independently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Layout A: (?,(?,?)):(?,(?,?))\n",
      ">>> Tiler: (3, 8)\n",
      ">>> By-mode Composition Result: (3,(?,?)):(?,(?,?))\n",
      ">?? Layout A: (12,(4,8)):(59,(13,1))\n",
      ">?? Tiler: (3,8)\n",
      ">?? By-mode Composition Result: (3,(4,2)):(59,(13,1))\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def bymode_composition_example():\n",
    "    \"\"\"\n",
    "    Demonstrates by-mode composition using a tiler\n",
    "    \"\"\"\n",
    "    # Define the original layout A\n",
    "    A = cute.make_layout(\n",
    "        (cutlass.Int32(12), (cutlass.Int32(4), cutlass.Int32(8))), \n",
    "        stride=(cutlass.Int32(59), (cutlass.Int32(13), cutlass.Int32(1)))\n",
    "    )\n",
    "\n",
    "    # Define the tiler for by-mode composition\n",
    "    tiler = (3, 8) # Apply 3:1 to mode-0 and 8:1 to mode-1\n",
    "\n",
    "    # Apply by-mode composition\n",
    "    result = cute.composition(A, tiler)\n",
    "\n",
    "    # Print static and dynamic information\n",
    "    print(\">>> Layout A:\", A)\n",
    "    cute.printf(\">?? Layout A: {}\", A)\n",
    "    print(\">>> Tiler:\", tiler)\n",
    "    cute.printf(\">?? Tiler: {}\", tiler)\n",
    "    print(\">>> By-mode Composition Result:\", result)\n",
    "    cute.printf(\">?? By-mode Composition Result: {}\", result)\n",
    "\n",
    "bymode_composition_example()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Division (Splitting into Tiles)\n",
    "\n",
    "The Division operation in CuTe is used to split a layout into tiles, which is particularly useful for partitioning data across threads or memory hierarchies.\n",
    "\n",
    "#### Examples :\n",
    "\n",
    "- Logical divide :\n",
    "\n",
    "When applied to two Layouts, `logical_divide` splits a layout into two modes -- the first mode contains the elements pointed to by the tiler, and the second mode contains the remaining elements."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Layout: (4,2,3):(2,1,8)\n",
      ">>> Tiler : 4:2\n",
      ">>> Logical Divide Result: ((2,2),(2,3)):((4,1),(2,8))\n",
      ">?? Logical Divide Result: ((2,2),(2,3)):((4,1),(2,8))\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def logical_divide_1d_example():\n",
    "    \"\"\"\n",
    "    Demonstrates 1D logical divide\n",
    "    \"\"\"\n",
    "    # Define the original layout\n",
    "    layout = cute.make_layout((4, 2, 3), stride=(2, 1, 8))  # (4,2,3):(2,1,8)\n",
    "    \n",
    "    # Define the tiler\n",
    "    tiler = cute.make_layout(4, stride=2)  # Apply to layout 4:2\n",
    "    \n",
    "    # Apply logical divide\n",
    "    result = cute.logical_divide(layout, tiler=tiler)\n",
    "    \n",
    "    # Print results\n",
    "    print(\">>> Layout:\", layout)\n",
    "    print(\">>> Tiler :\", tiler)\n",
    "    print(\">>> Logical Divide Result:\", result)\n",
    "    cute.printf(\">?? Logical Divide Result: {}\", result)\n",
    "\n",
    "logical_divide_1d_example()"
   ]
  },
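  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The CuTe C++ documentation defines `logical_divide(A, B)` as the composition of `A` with the pair `(B, B*)`, where `B*` is the complement of `B` with respect to `size(A)`. For the example above, that documentation gives `B* = (2,3):(1,8)`. The next cell is a minimal pure-Python sketch that checks this identity against the printed result. The helpers `size` and `layout_index` are hypothetical names used only for this illustration (redefined here so the cell runs on its own); they are not part of `cutlass.cute`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative model only (hypothetical helpers, not the cutlass.cute API).\n",
    "def size(shape):\n",
    "    # Number of coordinates in a (possibly nested) shape.\n",
    "    if isinstance(shape, tuple):\n",
    "        n = 1\n",
    "        for s in shape:\n",
    "            n *= size(s)\n",
    "        return n\n",
    "    return shape\n",
    "\n",
    "def layout_index(i, shape, stride):\n",
    "    # Evaluate the layout shape:stride at 1-D coordinate i (leftmost mode fastest).\n",
    "    if not isinstance(shape, tuple):\n",
    "        return i * stride\n",
    "    total = 0\n",
    "    for s, d in zip(shape, stride):\n",
    "        total += layout_index(i % size(s), s, d)\n",
    "        i //= size(s)\n",
    "    return total\n",
    "\n",
    "A = ((4, 2, 3), (2, 1, 8))                     # layout being divided\n",
    "tiler_and_rest = ((4, (2, 3)), (2, (1, 8)))    # (B, B*) = (4:2, (2,3):(1,8))\n",
    "result = (((2, 2), (2, 3)), ((4, 1), (2, 8)))  # printed logical_divide result\n",
    "\n",
    "composed = [layout_index(layout_index(c, *tiler_and_rest), *A) for c in range(24)]\n",
    "direct = [layout_index(c, *result) for c in range(24)]\n",
    "print(composed == direct)\n",
    "print(direct[:4])  # first mode: the elements selected by the tiler"
   ]
  },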
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When applied to a Layout and a `Tiler` tuple, `logical_divide` applies itself to the leaves of the `Tiler`and the corresponding mode of the target Layout. This means that the sublayouts are split independently according to the layouts within the `Tiler`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Layout: (9,(4,8)):(59,(13,1))\n",
      ">>> Tiler : (<cutlass.cute.core._Layout object at 0x7fc95a4ca7b0>, <cutlass.cute.core._Layout object at 0x7fc958160f50>)\n",
      ">>> Logical Divide Result: ((3,3),((2,4),(2,2))):((177,59),((13,2),(26,1)))\n",
      ">?? Logical Divide Result: ((3,3),((2,4),(2,2))):((177,59),((13,2),(26,1)))\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def logical_divide_2d_example():\n",
    "    \"\"\"\n",
    "    Demonstrates 2D logical divide :\n",
    "    Layout Shape : (M, N, L, ...)\n",
    "    Tiler Shape  : <TileM, TileN>\n",
    "    Result Shape : ((TileM,RestM), (TileN,RestN), L, ...)\n",
    "    \"\"\"\n",
    "    # Define the original layout\n",
    "    layout = cute.make_layout((9, (4, 8)), stride=(59, (13, 1)))  # (9,(4,8)):(59,(13,1))\n",
    "    \n",
    "    # Define the tiler\n",
    "    tiler = (cute.make_layout(3, stride=3),            # Apply to mode-0 layout 3:3\n",
    "             cute.make_layout((2, 4), stride=(1, 8)))  # Apply to mode-1 layout (2,4):(1,8)\n",
    "    \n",
    "    # Apply logical divide\n",
    "    result = cute.logical_divide(layout, tiler=tiler)\n",
    "    \n",
    "    # Print results\n",
    "    print(\">>> Layout:\", layout)\n",
    "    print(\">>> Tiler :\", tiler)\n",
    "    print(\">>> Logical Divide Result:\", result)\n",
    "    cute.printf(\">?? Logical Divide Result: {}\", result)\n",
    "\n",
    "logical_divide_2d_example()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Zipped, tiled, and flat divide are flavors of `logical_divide` that potentially rearrange modes into more convenient forms.\n",
    "\n",
    "- Zipped Divide :"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Layout: (9,(4,8)):(59,(13,1))\n",
      ">>> Tiler : (<cutlass.cute.core._Layout object at 0x7fc95a4ca7b0>, <cutlass.cute.core._Layout object at 0x7fc9581611f0>)\n",
      ">>> Zipped Divide Result: ((3,(2,4)),(3,(2,2))):((177,(13,2)),(59,(26,1)))\n",
      ">?? Zipped Divide Result: ((3,(2,4)),(3,(2,2))):((177,(13,2)),(59,(26,1)))\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def zipped_divide_example():\n",
    "    \"\"\"\n",
    "    Demonstrates zipped divide :\n",
    "    Layout Shape : (M, N, L, ...)\n",
    "    Tiler Shape  : <TileM, TileN>\n",
    "    Result Shape : ((TileM,TileN), (RestM,RestN,L,...))\n",
    "    \"\"\"\n",
    "    # Define the original layout\n",
    "    layout = cute.make_layout((9, (4, 8)), stride=(59, (13, 1)))  # (9,(4,8)):(59,(13,1))\n",
    "    \n",
    "    # Define the tiler\n",
    "    tiler = (cute.make_layout(3, stride=3),            # Apply to mode-0 layout 3:3\n",
    "             cute.make_layout((2, 4), stride=(1, 8)))  # Apply to mode-1 layout (2,4):(1,8)\n",
    "    \n",
    "    # Apply zipped divide\n",
    "    result = cute.zipped_divide(layout, tiler=tiler)\n",
    "    \n",
    "    # Print results\n",
    "    print(\">>> Layout:\", layout)\n",
    "    print(\">>> Tiler :\", tiler)\n",
    "    print(\">>> Zipped Divide Result:\", result)\n",
    "    cute.printf(\">?? Zipped Divide Result: {}\", result)\n",
    "\n",
    "zipped_divide_example()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Tiled Divide :"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Layout: (9,(4,8)):(59,(13,1))\n",
      ">>> Tiler : (<cutlass.cute.core._Layout object at 0x7fc9581610d0>, <cutlass.cute.core._Layout object at 0x7fc958161070>)\n",
      ">>> Tiled Divide Result: ((3,(2,4)),3,(2,2)):((177,(13,2)),59,(26,1))\n",
      ">?? Tiled Divide Result: ((3,(2,4)),3,(2,2)):((177,(13,2)),59,(26,1))\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def tiled_divide_example():\n",
    "    \"\"\"\n",
    "    Demonstrates tiled divide :\n",
    "    Layout Shape : (M, N, L, ...)\n",
    "    Tiler Shape  : <TileM, TileN>\n",
    "    Result Shape : ((TileM,TileN), RestM, RestN, L, ...)\n",
    "    \"\"\"\n",
    "    # Define the original layout\n",
    "    layout = cute.make_layout((9, (4, 8)), stride=(59, (13, 1)))  # (9,(4,8)):(59,(13,1))\n",
    "    \n",
    "    # Define the tiler\n",
    "    tiler = (cute.make_layout(3, stride=3),            # Apply to mode-0 layout 3:3\n",
    "             cute.make_layout((2, 4), stride=(1, 8)))  # Apply to mode-1 layout (2,4):(1,8)\n",
    "    \n",
    "    # Apply tiled divide\n",
    "    result = cute.tiled_divide(layout, tiler=tiler)\n",
    "    \n",
    "    # Print results\n",
    "    print(\">>> Layout:\", layout)\n",
    "    print(\">>> Tiler :\", tiler)\n",
    "    print(\">>> Tiled Divide Result:\", result)\n",
    "    cute.printf(\">?? Tiled Divide Result: {}\", result)\n",
    "\n",
    "tiled_divide_example()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Flat Divide :"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Layout: (9,(4,8)):(59,(13,1))\n",
      ">>> Tiler : (<cutlass.cute.core._Layout object at 0x7fc958161430>, <cutlass.cute.core._Layout object at 0x7fc9581610d0>)\n",
      ">>> Flat Divide Result: (3,(2,4),3,(2,2)):(177,(13,2),59,(26,1))\n",
      ">?? Flat Divide Result: (3,(2,4),3,(2,2)):(177,(13,2),59,(26,1))\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def flat_divide_example():\n",
    "    \"\"\"\n",
    "    Demonstrates flat divide :\n",
    "    Layout Shape : (M, N, L, ...)\n",
    "    Tiler Shape  : <TileM, TileN>\n",
    "    Result Shape : (TileM, TileN, RestM, RestN, L, ...)\n",
    "    \"\"\"\n",
    "    # Define the original layout\n",
    "    layout = cute.make_layout((9, (4, 8)), stride=(59, (13, 1)))  # (9,(4,8)):(59,(13,1))\n",
    "    \n",
    "    # Define the tiler\n",
    "    tiler = (cute.make_layout(3, stride=3),            # Apply to mode-0 layout 3:3\n",
    "             cute.make_layout((2, 4), stride=(1, 8)))  # Apply to mode-1 layout (2,4):(1,8)\n",
    "    \n",
    "    # Apply flat divide\n",
    "    result = cute.flat_divide(layout, tiler=tiler)\n",
    "    \n",
    "    # Print results\n",
    "    print(\">>> Layout:\", layout)\n",
    "    print(\">>> Tiler :\", tiler)\n",
    "    print(\">>> Flat Divide Result:\", result)\n",
    "    cute.printf(\">?? Flat Divide Result: {}\", result)\n",
    "\n",
    "flat_divide_example()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Product (Reproducing a Tile)\n",
    "\n",
    "The Product operation in CuTe is used to reproduce one layout according to another layout. It creates a new layout where:\n",
    "- The first mode is the original layout A.\n",
    "- The second mode is a restrided layout B that points to the origin of a \"unique replication\" of A.\n",
    "\n",
    "This is particularly useful for repeating layouts of threads across a tile of data for creating \"repeat\" patterns.\n",
    "\n",
    "#### Examples\n",
    "\n",
    "- Logical Product :"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Layout: (2,2):(4,1)\n",
      ">>> Tiler : 6:1\n",
      ">>> Logical Product Result: ((2,2),(2,3)):((4,1),(2,8))\n",
      ">?? Logical Product Result: ((2,2),(2,3)):((4,1),(2,8))\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def logical_product_1d_example():\n",
    "    \"\"\"\n",
    "    Demonstrates 1D logical product\n",
    "    \"\"\"\n",
    "    # Define the original layout\n",
    "    layout = cute.make_layout((2, 2), stride=(4, 1))  # (2,2):(4,1)\n",
    "    \n",
    "    # Define the tiler\n",
    "    tiler = cute.make_layout(6, stride=1)  # Apply to layout 6:1\n",
    "    \n",
    "    # Apply logical product\n",
    "    result = cute.logical_product(layout, tiler=tiler)\n",
    "    \n",
    "    # Print results\n",
    "    print(\">>> Layout:\", layout)\n",
    "    print(\">>> Tiler :\", tiler)\n",
    "    print(\">>> Logical Product Result:\", result)\n",
    "    cute.printf(\">?? Logical Product Result: {}\", result)\n",
    "\n",
    "logical_product_1d_example()"
   ]
  },
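  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see the two modes of the product concretely, the next cell is a minimal pure-Python sketch that evaluates the printed result `((2,2),(2,3)):((4,1),(2,8))` at every 1-D coordinate. The first four values are the original tile `(2,2):(4,1)`, and every fourth value is the origin of one replication of that tile. The helpers `size` and `layout_index` are hypothetical names used only for this illustration (redefined here so the cell runs on its own); they are not part of `cutlass.cute`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative model only (hypothetical helpers, not the cutlass.cute API).\n",
    "def size(shape):\n",
    "    # Number of coordinates in a (possibly nested) shape.\n",
    "    if isinstance(shape, tuple):\n",
    "        n = 1\n",
    "        for s in shape:\n",
    "            n *= size(s)\n",
    "        return n\n",
    "    return shape\n",
    "\n",
    "def layout_index(i, shape, stride):\n",
    "    # Evaluate the layout shape:stride at 1-D coordinate i (leftmost mode fastest).\n",
    "    if not isinstance(shape, tuple):\n",
    "        return i * stride\n",
    "    total = 0\n",
    "    for s, d in zip(shape, stride):\n",
    "        total += layout_index(i % size(s), s, d)\n",
    "        i //= size(s)\n",
    "    return total\n",
    "\n",
    "result = (((2, 2), (2, 3)), ((4, 1), (2, 8)))  # printed logical_product result\n",
    "values = [layout_index(c, *result) for c in range(size(result[0]))]\n",
    "\n",
    "tile = values[:4]      # first mode: the original layout (2,2):(4,1)\n",
    "origins = values[::4]  # second mode: origin of each replication of the tile\n",
    "print(tile)\n",
    "print(origins)\n",
    "print(sorted(values) == list(range(24)))  # the six copies do not overlap"
   ]
  },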
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Blocked and Raked Product :\n",
    "  \n",
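    "  - Blocked Product: Combines the modes of A and B in a block-like fashion, reassociating them after the product so that each mode keeps its semantic meaning. In the example below, the result shape `((2,3),(5,4))` pairs mode 0 of A (2) with mode 0 of B (3), and mode 1 of A (5) with mode 1 of B (4).\n",
    "  - Raked Product: Combines the modes of A and B in an interleaved or \"raked\" fashion, creating a cyclic distribution of the tiles. In the example below, the result shape `((3,2),(4,5))` places the extents of B innermost in each mode."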
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Layout: (2,5):(5,1)\n",
      ">>> Tiler : (3,4):(1,3)\n",
      ">>> Blocked Product Result: ((2,3),(5,4)):((5,10),(1,30))\n",
      ">>> Raked Product Result: ((3,2),(4,5)):((10,5),(30,1))\n",
      ">?? Blocked Product Result: ((2,3),(5,4)):((5,10),(1,30))\n",
      ">?? Raked Product Result: ((3,2),(4,5)):((10,5),(30,1))\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def blocked_raked_product_example():\n",
    "    \"\"\"\n",
    "    Demonstrates blocked and raked products\n",
    "    \"\"\"\n",
    "    # Define the original layout\n",
    "    layout = cute.make_layout((2, 5), stride=(5, 1))\n",
    "    \n",
    "    # Define the tiler\n",
    "    tiler = cute.make_layout((3, 4), stride=(1, 3))\n",
    "    \n",
    "    # Apply blocked product\n",
    "    blocked_result = cute.blocked_product(layout, tiler=tiler)\n",
    "\n",
    "    # Apply raked product\n",
    "    raked_result = cute.raked_product(layout, tiler=tiler)\n",
    "    \n",
    "    # Print results\n",
    "    print(\">>> Layout:\", layout)\n",
    "    print(\">>> Tiler :\", tiler)\n",
    "    print(\">>> Blocked Product Result:\", blocked_result)\n",
    "    print(\">>> Raked Product Result:\", raked_result)\n",
    "    cute.printf(\">?? Blocked Product Result: {}\", blocked_result)\n",
    "    cute.printf(\">?? Raked Product Result: {}\", raked_result)\n",
    "\n",
    "blocked_raked_product_example()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Zipped, tiled, and flat product :\n",
    "  \n",
    "  - As with the divide operations, `zipped_product`, `tiled_product`, and `flat_product` are variants of `logical_product` that rearrange the resulting modes into more convenient forms."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Layout: (2,5):(5,1)\n",
      ">>> Tiler : (3,4):(1,3)\n",
      ">>> Zipped Product Result: ((2,5),(3,4)):((5,1),(10,30))\n",
      ">>> Tiled Product Result: ((2,5),3,4):((5,1),10,30)\n",
      ">>> Flat Product Result: (2,5,3,4):(5,1,10,30)\n",
      ">?? Zipped Product Result: ((2,5),(3,4)):((5,1),(10,30))\n",
      ">?? Tiled Product Result: ((2,5),3,4):((5,1),10,30)\n",
      ">?? Flat Product Result: (2,5,3,4):(5,1,10,30)\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def zipped_tiled_flat_product_example():\n",
    "    \"\"\"\n",
    "    Demonstrates zipped, tiled, and flat products\n",
    "    Layout Shape : (M, N, L, ...)\n",
    "    Tiler Shape  : <TileM, TileN>\n",
    "\n",
    "    zipped_product  : ((M,N), (TileM,TileN,L,...))\n",
    "    tiled_product   : ((M,N), TileM, TileN, L, ...)\n",
    "    flat_product    : (M, N, TileM, TileN, L, ...)\n",
    "    \"\"\"\n",
    "    # Define the original layout\n",
    "    layout = cute.make_layout((2, 5), stride=(5, 1))\n",
    "    \n",
    "    # Define the tiler\n",
    "    tiler = cute.make_layout((3, 4), stride=(1, 3))\n",
    "\n",
    "    # Apply zipped product\n",
    "    zipped_result = cute.zipped_product(layout, tiler=tiler)\n",
    "    \n",
    "    # Apply tiled product\n",
    "    tiled_result = cute.tiled_product(layout, tiler=tiler)\n",
    "    \n",
    "    # Apply flat product\n",
    "    flat_result = cute.flat_product(layout, tiler=tiler)\n",
    "\n",
    "    # Print results\n",
    "    print(\">>> Layout:\", layout)\n",
    "    print(\">>> Tiler :\", tiler)\n",
    "    print(\">>> Zipped Product Result:\", zipped_result)\n",
    "    print(\">>> Tiled Product Result:\", tiled_result)\n",
    "    print(\">>> Flat Product Result:\", flat_result)\n",
    "    cute.printf(\">?? Zipped Product Result: {}\", zipped_result)\n",
    "    cute.printf(\">?? Tiled Product Result: {}\", tiled_result)\n",
    "    cute.printf(\">?? Flat Product Result: {}\", flat_result)\n",
    "\n",
    "zipped_tiled_flat_product_example()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pythondsl_venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
